{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# This notebook is adapted from: \n",
    "https://github.com/ray-project/tutorial/tree/master/examples/sharded_parameter_server.ipynb\n",
    "\n",
    "# Sharded Parameter Servers\n",
    "\n",
    "**GOAL:** The goal of this exercise is to use actor handles to implement a sharded parameter server example for **distributed asynchronous stochastic gradient descent**.\n",
    "\n",
    "Before doing this exercise, make sure you understand the concepts from the exercise on **Actor Handles**.\n",
    "\n",
    "### Parameter Servers\n",
    "\n",
    "A parameter server is simply an object that stores the parameters (or \"weights\") of a machine learning model (this could be a neural network, a linear model, or something else). It exposes two methods: one for getting the parameters and one for updating the parameters.\n",
    "\n",
    "In a typical machine learning training application, worker processes will run in an infinite loop that does the following:\n",
    "1. Get the latest parameters from the parameter server.\n",
    "2. Compute an update to the parameters (using the current parameters and some data).\n",
    "3. Send the update to the parameter server.\n",
    "\n",
    "The workers can operate synchronously (that is, in lock step), in which case distributed training with multiple workers is algorithmically equivalent to serial training with a larger batch of data. Alternatively, workers can operate independently and apply their updates asynchronously. The main benefit of asynchronous training is that a single slow worker will not slow down the other workers. The benefit of synchronous training is that the algorithm behavior is more predictable and reproducible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import numpy as np\n",
    "import ray\n",
    "import time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Init SparkContext"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Current pyspark location is : /root/anaconda2/envs/ray_train/lib/python3.6/site-packages/pyspark/__init__.py\n",
      "Start to pack current python env\n",
      "Collecting packages...\n",
      "Packing environment at '/root/anaconda2/envs/ray_train' to '/tmp/tmp7qvxc3o2/python_env.tar.gz'\n",
      "[########################################] | 100% Completed | 34.4s\n",
      "Packing has been completed: /tmp/tmp7qvxc3o2/python_env.tar.gz\n",
      "pyspark_submit_args is:  --master yarn --deploy-mode client --archives /tmp/tmp7qvxc3o2/python_env.tar.gz#python_env --num-executors 2  --executor-cores 4 --executor-memory 2g pyspark-shell \n"
     ]
    }
   ],
   "source": [
    "from zoo.common.nncontext import init_spark_on_local, init_spark_on_yarn\n",
    "import numpy as np\n",
    "import os\n",
    "hadoop_conf_dir = os.environ.get('HADOOP_CONF_DIR')\n",
    "\n",
    "if hadoop_conf_dir:\n",
    "    sc = init_spark_on_yarn(\n",
    "    hadoop_conf=hadoop_conf_dir,\n",
    "    conda_name=os.environ.get(\"ZOO_CONDA_NAME\", \"zoo\"), # The name of the created conda-env\n",
    "    num_executor=2,\n",
    "    executor_cores=4,\n",
    "    executor_memory=\"2g\",\n",
    "    driver_memory=\"2g\",\n",
    "    driver_cores=1,\n",
    "    extra_executor_memory_for_ray=\"3g\")\n",
    "else:\n",
    "    sc = init_spark_on_local(cores = 8, conf = {\"spark.driver.memory\": \"2g\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Start to launch the JVM guarding process\n",
      "JVM guarding process has been successfully launched\n",
      "Start to launch ray on cluster\n",
      "Start to launch ray on local\n",
      "Executing command: ray start --redis-address 172.16.0.158:34046 --redis-password  123456 --num-cpus 0 --object-store-memory 400000000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2019-07-18 07:09:19,971\tWARNING worker.py:1341 -- WARNING: Not updating worker name since `setproctitle` is not installed. Install this with `pip install setproctitle` (or ray[debug]) to enable monitoring of worker processes.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "2019-07-18 07:09:19,855\tINFO services.py:409 -- Waiting for redis server at 172.16.0.158:34046 to respond...\n",
      "2019-07-18 07:09:19,858\tINFO scripts.py:363 -- Using IP address 172.16.0.102 for this node.\n",
      "2019-07-18 07:09:19,861\tINFO node.py:511 -- Process STDOUT and STDERR is being redirected to /tmp/ray/session_2019-07-18_15-09-10_137772_188428/logs.\n",
      "2019-07-18 07:09:19,862\tINFO services.py:1441 -- Starting the Plasma object store with 0.4 GB memory using /dev/shm.\n",
      "2019-07-18 07:09:19,887\tINFO scripts.py:371 -- \n",
      "Started Ray on this node. If you wish to terminate the processes that have been started, run\n",
      "\n",
      "    ray stop\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# It may take a while to ditribute the local environment including python and java to cluster\n",
    "import ray\n",
    "from zoo.ray import RayContext\n",
    "ray_ctx = RayContext(sc=sc, object_store_memory=\"4g\")\n",
    "ray_ctx.init()\n",
    "#ray.init(num_cpus=30, include_webui=False, ignore_reinit_error=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A simple parameter server can be implemented as a Python class in a few lines of code.\n",
    "\n",
    "**EXERCISE:** Make the `ParameterServer` class an actor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "dim = 10\n",
    "@ray.remote\n",
    "class ParameterServer(object):\n",
    "    def __init__(self, dim):\n",
    "        self.parameters = np.zeros(dim)\n",
    "    \n",
    "    def get_parameters(self):\n",
    "        return self.parameters\n",
    "    \n",
    "    def update_parameters(self, update):\n",
    "        self.parameters += update\n",
    "\n",
    "\n",
    "ps = ParameterServer.remote(dim)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A worker can be implemented as a simple Python function that repeatedly gets the latest parameters, computes an update to the parameters, and sends the update to the parameter server."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "@ray.remote\n",
    "def worker(ps, dim, num_iters):\n",
    "    for _ in range(num_iters):\n",
    "        # Get the latest parameters.\n",
    "        parameters = ray.get(ps.get_parameters.remote())\n",
    "        # Compute an update.\n",
    "        update = 1e-3 * parameters + np.ones(dim)\n",
    "        # Update the parameters.\n",
    "        ps.update_parameters.remote(update)\n",
    "        # Sleep a little to simulate a real workload.\n",
    "        time.sleep(0.5)\n",
    "\n",
    "# Test that worker is implemented correctly. You do not need to change this line.\n",
    "ray.get(worker.remote(ps, dim, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start two workers.\n",
    "worker_results = [worker.remote(ps, dim, 100) for _ in range(2)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As the worker tasks are executing, you can query the parameter server from the driver and see the parameters changing in the background."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[19.16281869 19.16281869 19.16281869 19.16281869 19.16281869 19.16281869\n",
      " 19.16281869 19.16281869 19.16281869 19.16281869]\n"
     ]
    }
   ],
   "source": [
    "print(ray.get(ps.get_parameters.remote()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sharding a Parameter Server\n",
    "\n",
    "As the number of workers increases, the volume of updates being sent to the parameter server will increase. At some point, the network bandwidth into the parameter server machine or the computation down by the parameter server may be a bottleneck.\n",
    "\n",
    "Suppose you have $N$ workers and $1$ parameter server, and suppose each of these is an actor that lives on its own machine. Furthermore, suppose the model size is $M$ bytes. Then sending all of the parameters from the workers to the parameter server will mean that $N * M$ bytes in total are sent to the parameter server. If $N = 100$ and $M = 10^8$, then the parameter server must receive ten gigabytes, which, assuming a network bandwidth of 10 giga*bits* per second, would take 8 seconds. This would be prohibitive.\n",
    "\n",
    "On the other hand, if the parameters are sharded (that is, split) across `K` parameter servers, `K` is `100`, and each parameter server lives on a separate machine, then each parameter server needs to receive only 100 megabytes, which can be done in 80 milliseconds. This is much better.\n",
    "\n",
    "**EXERCISE:** The code below defines a parameter server shard class. Modify this class to make `ParameterServerShard` an actor. We will need to revisit this code soon and increase `num_shards`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "@ray.remote\n",
    "class ParameterServerShard(object):\n",
    "    def __init__(self, sharded_dim):\n",
    "        self.parameters = np.zeros(sharded_dim)\n",
    "    \n",
    "    def get_parameters(self):\n",
    "        return self.parameters\n",
    "    \n",
    "    def update_parameters(self, update):\n",
    "        self.parameters += update\n",
    "\n",
    "\n",
    "total_dim = (10 ** 8) // 8  # This works out to 100MB (we have 25 million\n",
    "                            # float64 values, which are each 8 bytes).\n",
    "num_shards = 2  # The number of parameter server shards.\n",
    "\n",
    "assert total_dim % num_shards == 0, ('In this exercise, the number of shards must '\n",
    "                                     'perfectly divide the total dimension.')\n",
    "\n",
    "# Start some parameter servers.\n",
    "ps_shards = [ParameterServerShard.remote(total_dim // num_shards) for _ in range(num_shards)]\n",
    "\n",
    "assert hasattr(ParameterServerShard, 'remote'), ('You need to turn ParameterServerShard into an '\n",
    "                                                 'actor (by using the ray.remote keyword).')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The code below implements a worker that does the following.\n",
    "1. Gets the latest parameters from all of the parameter server shards.\n",
    "2. Concatenates the parameters together to form the full parameter vector.\n",
    "3. Computes an update to the parameters.\n",
    "4. Partitions the update into one piece for each parameter server.\n",
    "5. Applies the right update to each parameter server shard."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "@ray.remote\n",
    "def worker_task(total_dim, num_iters, *ps_shards):\n",
    "    # Note that ps_shards are passed in using Python's variable number\n",
    "    # of arguments feature. We do this because currently actor handles\n",
    "    # cannot be passed to tasks inside of lists or other objects.\n",
    "    for _ in range(num_iters):\n",
    "        # Get the current parameters from each parameter server.\n",
    "        parameter_shards = [ray.get(ps.get_parameters.remote()) for ps in ps_shards]\n",
    "        assert all([isinstance(shard, np.ndarray) for shard in parameter_shards]), (\n",
    "               'The parameter shards must be numpy arrays. Did you forget to call ray.get?')\n",
    "        # Concatenate them to form the full parameter vector.\n",
    "        parameters = np.concatenate(parameter_shards)\n",
    "        assert parameters.shape == (total_dim,)\n",
    "\n",
    "        # Compute an update.\n",
    "        update = np.ones(total_dim)\n",
    "        # Shard the update.\n",
    "        update_shards = np.split(update, len(ps_shards))\n",
    "        \n",
    "        # Apply the updates to the relevant parameter server shards.\n",
    "        for ps, update_shard in zip(ps_shards, update_shards):\n",
    "            ps.update_parameters.remote(update_shard)\n",
    "\n",
    "\n",
    "# Test that worker_task is implemented correctly. You do not need to change this line.\n",
    "ray.get(worker_task.remote(total_dim, 1, *ps_shards))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**EXERCISE:** Experiment by changing the number of parameter server shards, the number of workers, and the size of the data.\n",
    "\n",
    "**NOTE:** Because these processes are all running on the same machine, network bandwidth will not be a limitation and sharding the parameter server will not help. To see the difference, you would need to run the application on multiple machines. There are still regimes where sharding a parameter server can help speed up computation on the same machine (by parallelizing the computation that the parameter server processes have to do). If you want to see this effect, you should implement a synchronous training application. In the asynchronous setting, the computation is staggered and so speeding up the parameter server usually does not matter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "This took 4.21185827255249 seconds.\n"
     ]
    }
   ],
   "source": [
    "num_workers = 4\n",
    "\n",
    "# Start some workers. Try changing various quantities and see how the\n",
    "# duration changes.\n",
    "start = time.time()\n",
    "ray.get([worker_task.remote(total_dim, 5, *ps_shards) for _ in range(num_workers)])\n",
    "print('This took {} seconds.'.format(time.time() - start))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (ray_train)",
   "language": "python",
   "name": "ray_train"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}